# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Encodec Model config
"""
import math
from typing import Optional
import numpy as np


from ...configuration_utils import PretrainedConfig


# Mapping from canonical pretrained checkpoint names to the URLs of their
# hosted `config.json` files (mirror of the facebook/encodec checkpoints).
ENCODEC_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "encodec_24khz": "https://hf-mirror.com/encodec_24khz/resolve/main/config.json",
    "encodec_48khz": "https://hf-mirror.com/encodec_48khz/resolve/main/config.json",
}

# Names exported by `from <module> import *`.
__all__ = ["EncodecConfig"]


class EncodecConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of an [`EncodecModel`]. It is used to instantiate a
    Encodec model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the
    [facebook/encodec_24khz](https://hf-mirror.com/facebook/encodec_24khz) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        target_bandwidths (`List[float]`, *optional*, defaults to `[1.5, 3.0, 6.0, 12.0, 24.0]`):
            The range of different bandwidths the model can encode audio with.
        sampling_rate (`int`, *optional*, defaults to 24000):
            The sampling rate at which the audio waveform should be digitalized expressed in hertz (Hz).
        audio_channels (`int`, *optional*, defaults to 1):
            Number of channels in the audio data. Either 1 for mono or 2 for stereo.
        normalize (`bool`, *optional*, defaults to `False`):
            Whether the audio shall be normalized when passed.
        chunk_length_s (`float`, *optional*):
            If defined the audio is pre-processed into chunks of lengths `chunk_length_s` and then encoded.
        overlap (`float`, *optional*):
            Defines the overlap between each chunk. It is used to compute the `chunk_stride` using the following
            formulae : `int((1.0 - self.overlap) * self.chunk_length)`.
        hidden_size (`int`, *optional*, defaults to 128):
            Intermediate representation dimension.
        num_filters (`int`, *optional*, defaults to 32):
            Number of convolution kernels of first `EncodecConv1d` down sampling layer.
        num_residual_layers (`int`,  *optional*, defaults to 1):
            Number of residual layers.
        upsampling_ratios (`Sequence[int]` , *optional*, defaults to `[8, 5, 4, 2]`):
            Kernel size and stride ratios. The encoder uses downsampling ratios instead of upsampling ratios, hence it
            will use the ratios in the reverse order to the ones specified here that must match the decoder order.
        norm_type (`str`, *optional*, defaults to `"weight_norm"`):
            Normalization method. Should be in `["weight_norm", "time_group_norm"]`
        kernel_size (`int`, *optional*, defaults to 7):
            Kernel size for the initial convolution.
        last_kernel_size (`int`, *optional*, defaults to 7):
            Kernel size for the last convolution layer.
        residual_kernel_size (`int`, *optional*, defaults to 3):
            Kernel size for the residual layers.
        dilation_growth_rate (`int`, *optional*, defaults to 2):
            How much to increase the dilation with each layer.
        use_causal_conv (`bool`, *optional*, defaults to `True`):
            Whether to use fully causal convolution.
        pad_mode (`str`, *optional*, defaults to `"reflect"`):
            Padding mode for the convolutions.
        compress (`int`, *optional*, defaults to 2):
            Reduced dimensionality in residual branches (from Demucs v3).
        num_lstm_layers (`int`, *optional*, defaults to 2):
            Number of LSTM layers at the end of the encoder.
        trim_right_ratio (`float`, *optional*, defaults to 1.0):
            Ratio for trimming at the right of the transposed convolution under the `use_causal_conv = True` setup. If
            equal to 1.0, it means that all the trimming is done at the right.
        codebook_size (`int`, *optional*, defaults to 1024):
            Number of discrete codes that make up VQVAE.
        codebook_dim (`int`, *optional*):
            Dimension of the codebook vectors. If not defined, uses `hidden_size`.
        use_conv_shortcut (`bool`, *optional*, defaults to `True`):
            Whether to use a convolutional layer as the 'skip' connection in the `EncodecResnetBlock` block. If False,
            an identity function will be used, giving a generic residual connection.

    Example:
        ```python
        >>> from transformers import EncodecModel, EncodecConfig
        ...
        >>> # Initializing a "facebook/encodec_24khz" style configuration
        >>> configuration = EncodecConfig()
        ...
        >>> # Initializing a model (with random weights) from the "facebook/encodec_24khz" style configuration
        >>> model = EncodecModel(configuration)
        ...
        >>> # Accessing the model configuration
        >>> configuration = model.config
        ```
    """
    model_type = "encodec"

    def __init__(
        self,
        target_bandwidths=None,
        sampling_rate=24_000,
        audio_channels=1,
        normalize=False,
        chunk_length_s=None,
        overlap=None,
        hidden_size=128,
        num_filters=32,
        num_residual_layers=1,
        upsampling_ratios=None,
        norm_type="weight_norm",
        kernel_size=7,
        last_kernel_size=7,
        residual_kernel_size=3,
        dilation_growth_rate=2,
        use_causal_conv=True,
        pad_mode="reflect",
        compress=2,
        num_lstm_layers=2,
        trim_right_ratio=1.0,
        codebook_size=1024,
        codebook_dim=None,
        use_conv_shortcut=True,
        **kwargs,
    ):
        """
        Initialize an `EncodecConfig`.

        All arguments are documented in the class docstring. `target_bandwidths` and
        `upsampling_ratios` accept `None` as a sentinel for their documented defaults
        (`[1.5, 3.0, 6.0, 12.0, 24.0]` and `[8, 5, 4, 2]` respectively) so that no
        mutable list object is shared between calls. Extra keyword arguments are
        forwarded to `PretrainedConfig.__init__`.

        Raises:
            ValueError: If `norm_type` is not `"weight_norm"` or `"time_group_norm"`.
        """
        # Use None sentinels instead of mutable list defaults so each instance
        # gets its own list (a shared default list could be mutated by one
        # instance and silently affect every other).
        self.target_bandwidths = (
            [1.5, 3.0, 6.0, 12.0, 24.0] if target_bandwidths is None else target_bandwidths
        )
        self.sampling_rate = sampling_rate
        self.audio_channels = audio_channels
        self.normalize = normalize
        self.chunk_length_s = chunk_length_s
        self.overlap = overlap
        self.hidden_size = hidden_size
        self.num_filters = num_filters
        self.num_residual_layers = num_residual_layers
        self.upsampling_ratios = [8, 5, 4, 2] if upsampling_ratios is None else upsampling_ratios
        self.norm_type = norm_type
        self.kernel_size = kernel_size
        self.last_kernel_size = last_kernel_size
        self.residual_kernel_size = residual_kernel_size
        self.dilation_growth_rate = dilation_growth_rate
        self.use_causal_conv = use_causal_conv
        self.pad_mode = pad_mode
        self.compress = compress
        self.num_lstm_layers = num_lstm_layers
        self.trim_right_ratio = trim_right_ratio
        self.codebook_size = codebook_size
        # Codebook vectors default to the model's hidden dimension.
        self.codebook_dim = codebook_dim if codebook_dim is not None else hidden_size
        self.use_conv_shortcut = use_conv_shortcut

        if self.norm_type not in ["weight_norm", "time_group_norm"]:
            raise ValueError(
                f'self.norm_type must be one of `"weight_norm"`, `"time_group_norm"`, got {self.norm_type}'
            )

        super().__init__(**kwargs)

    # This is a property because you might want to change the chunk_length_s on the fly
    @property
    def chunk_length(self) -> Optional[int]:
        r"""
        Chunk length in samples (`chunk_length_s * sampling_rate`), or `None`
        when no chunking is configured.
        """
        if self.chunk_length_s is None:
            return None
        return int(self.chunk_length_s * self.sampling_rate)

    # This is a property because you might want to change the chunk_length_s on the fly
    @property
    def chunk_stride(self) -> Optional[int]:
        r"""
        Stride between consecutive chunks in samples, derived from `overlap`
        (`int((1.0 - overlap) * chunk_length)`, clamped to at least 1), or
        `None` when chunking or overlap is not configured.
        """
        if self.chunk_length_s is None or self.overlap is None:
            return None
        return max(1, int((1.0 - self.overlap) * self.chunk_length))

    @property
    def frame_rate(self) -> int:
        r"""
        Number of codec frames per second: the sampling rate divided by the
        total hop length (product of the upsampling ratios), rounded up.
        """
        # math.prod over plain ints avoids pulling the value through a numpy scalar.
        hop_length = math.prod(self.upsampling_ratios)
        return math.ceil(self.sampling_rate / hop_length)

    @property
    def num_quantizers(self) -> int:
        r"""
        Number of residual vector quantizers needed to reach the largest
        target bandwidth (expressed in kbps) at the model's frame rate.
        """
        return int(1000 * self.target_bandwidths[-1] // (self.frame_rate * 10))
